import numpy
import matplotlib.pyplot as pyplot
from scipy.interpolate import interp1d
from scipy.optimize import brenth
from matplotlib.patches import ConnectionPatch
#from scipy import constants

# Physical constants and atomic quantum numbers shared by both sections below.
h = 6.63e-34 #Js -- Planck constant; converts the frequencies below to energies
TotalRb = 1.0/2.0  # electronic angular momentum J used for the Rb section
TotalCs = 0.5  # electronic angular momentum J used for the Cs section
spin = 1.0/2.0  # S, passed to find_g
orbital = 0.0  # L, passed to find_g
Bohr_magneton = 9.274e-24  # presumably in J/T -- TODO confirm units
NuclearRb = 3.0/2.0  # nuclear spin I used for the Rb section
NuclearCs =  7.0/2.0  # nuclear spin I used for the Cs section

# Hyperfine coefficients read as globals by H_dipole (A_HF) and H_Quad (B_HF).
# Both are reassigned before the second-atom section further down.
A_HF = 3.41734130545245e9 * h
B_HF = 1e-9#2.724e6 * h  (earlier value, left disabled by the author)

# Target for the root search on the interpolated transition curve. It is
# compared against energy/h, so presumably in Hz. The misspelt name is kept
# because later code refers to it.
Transtion_Freq = 7.60249e9

def Raising_operator(j):
	"""Return the matrix of the raising operator J+ for angular momentum j.

	The result is a (2j+1) x (2j+1) real array whose only non-zero entries
	sit on the first super-diagonal. Row/column index r corresponds to the
	projection m = j - r, i.e. the basis runs from m = +j down to m = -j.
	"""
	size = numpy.rint(2.0*j+1).astype(int)
	n_steps = numpy.rint(2.0*j).astype(int)
	ladder = numpy.zeros((size,size))
	# One entry per allowed step m-1 -> m, filled in a single indexed write.
	rows = numpy.arange(n_steps)
	m = j-rows
	ladder[rows,rows+1] = numpy.sqrt(j*(j+1)-m*(m-1))
	return ladder

#produce the three generalised projections of angular momentum:
def X_operator(J_plus):
	"""Return the x projection, (J+ + J-)/2, given the raising operator J+.

	J- is taken as the plain transpose of J_plus (no complex conjugation).
	"""
	return 0.5*(J_plus+numpy.transpose(J_plus))
def Y_operator(J_plus):
	"""Return the y projection, (i/2)(J- - J+), given the raising operator J+.

	J- is taken as the plain transpose of J_plus; the result is complex.
	"""
	difference = numpy.transpose(J_plus)-J_plus
	return 0.5j*difference
def Z_operator(J_plus):
	"""Return the z projection, (J+ J- - J- J+)/2, given the raising operator J+.

	J- is taken as the plain transpose of J_plus.
	"""
	raising = J_plus
	lowering = numpy.transpose(J_plus)
	commutator = numpy.dot(raising,lowering)-numpy.dot(lowering,raising)
	return 0.5*commutator

def vector_dot(x,y):
	"""Return the operator dot product sum_i x[i] . y[i].

	x and y are stacks of square matrices indexed along their first axis
	(x[0].shape fixes the shape of the result). Each pair is combined with
	a matrix product and the products are summed.

	The accumulator is complex so that complex components (such as the
	output of Y_operator) are summed without losing their imaginary part.
	"""
	# The builtin `complex` replaces the `numpy.complex` alias, which was
	# removed in NumPy 1.24 and raises AttributeError there.
	X_Y = numpy.zeros(x[0].shape,dtype=complex)
	for i, x_i in enumerate(x):
		X_Y += numpy.dot(x_i,y[i])
	return X_Y

def H_dipole(I,J):
	"""Return the dipole hyperfine term A_HF * (I . J).

	I and J are the operator stacks accepted by vector_dot. A_HF is read
	from the module-level global at call time.
	"""
	return A_HF*vector_dot(I,J)

def H_Quad(I,J,Nuclear,Total):
	"""Return the quadrupole hyperfine term built from the I and J stacks.

	The matrix is
	    B_HF * [3 (I.J)^2 + 3/2 (I.J) - (I.I)(J.J)]
	         / [2 Nuclear (2 Nuclear - 1) Total (2 Total - 1)],
	where the dot products come from vector_dot and B_HF is the module-level
	global. When the denominator is exactly zero (e.g. Total = 1/2) a zero
	matrix of the same shape is returned instead of dividing.
	"""
	i_dot_j = vector_dot(I,J)
	j_sq = vector_dot(J,J)
	i_sq = vector_dot(I,I)

	scale = 2*Nuclear*(2*Nuclear-1)*Total*(2*Total-1)
	bracket = 3*(numpy.dot(i_dot_j,i_dot_j))+3.0/2.0 *i_dot_j - numpy.dot(i_sq,j_sq)

	# Guard first: a vanishing denominator means the term is dropped entirely.
	if scale == 0:
		return numpy.zeros(bracket.shape)
	return B_HF*bracket/scale

def find_g(J,L,S):
	"""Return 1 + [J(J+1) - L(L+1) + S(S+1)] / [2 J(J+1)].

	J, L and S are the total, orbital and spin quantum numbers. J = 0 makes
	the denominator zero and is not handled here.
	"""
	jj = J*(J+1)
	ll = L*(L+1)
	ss = S*(S+1)
	return (jj-ll+ss)/(2*jj)+1

def H_zeeman(J,B,Total,orbital,spin):
	"""Return the Zeeman term Bohr_magneton * B * J[2] * g.

	J is an operator stack whose element at index 2 is used (the callers in
	this file put the Z_operator component there). B is the field strength
	and g is find_g(Total, orbital, spin). Bohr_magneton is the module-level
	global.
	"""
	g_factor = find_g(Total,orbital,spin)
	return Bohr_magneton*B*J[2]*g_factor

# ---- Rb 5^2S_1/2: hyperfine + Zeeman level shifts versus magnetic field ----

# Raising operators for the electronic (J) and nuclear (I) angular momenta.
J_plus = Raising_operator(TotalRb)
I_plus = Raising_operator(NuclearRb)

# Cartesian components on the product space: the J operators act on the first
# Kronecker factor and the I operators on the second.
dim_J = int(2*TotalRb+1)
dim_I = int(2*NuclearRb+1)
projections = (X_operator,Y_operator,Z_operator)
I_vec = numpy.array([numpy.kron(numpy.identity(dim_J),op(I_plus)) for op in projections])
J_vec = numpy.array([numpy.kron(op(J_plus),numpy.identity(dim_I)) for op in projections])

Hamiltonian_zero_field_quad = H_dipole(I_vec,J_vec)+H_Quad(I_vec,J_vec,NuclearRb,TotalRb)

# The Hamiltonian is Hermitian, so eigvalsh is used: it returns real
# eigenvalues already sorted in ascending order (eig returned complex values
# that were then cast into a float array).
Energies = numpy.linalg.eigvalsh(Hamiltonian_zero_field_quad)
print('////////////////')

B_array = numpy.linspace(1e-15,0.0500,10000)

# One row of sorted level energies per field value; the column count follows
# the matrix dimension instead of being hard-coded.
shiftsArray = numpy.zeros((len(B_array),Hamiltonian_zero_field_quad.shape[0]))
for b, field in enumerate(B_array):
	Hamiltonian = Hamiltonian_zero_field_quad + H_zeeman(J_vec,field,TotalRb,orbital,spin)
	shiftsArray[b,:] = numpy.linalg.eigvalsh(Hamiltonian)

# Prepend the exact zero-field point. numpy.insert returns a new array, so
# the result has to be assigned (it was previously discarded).
B_array = numpy.insert(B_array,0,0)
shiftsArray = numpy.insert(shiftsArray,0,Energies,0)

fig,ax = pyplot.subplots(2,1,sharex=True)

# The lower manifold F = I - J holds 2F+1 = 2(I-J)+1 states, i.e. sorted
# indices 0 .. 2(I-J). Those go on the bottom panel and everything above on
# the top panel. (The former `<=` test put the lowest upper-manifold level on
# the bottom panel, where it is far outside the y-range.)
n_lower = int(2*(NuclearRb-TotalRb)+1)
field_gauss = B_array*1e4
for i in range(1,shiftsArray.shape[1]):
	panel = ax[1] if i < n_lower else ax[0]
	panel.plot(field_gauss,(shiftsArray[:,i]-shiftsArray[:,0])/(h*1e9),color='black')


ax[1].set_xlabel('Magnetic Field [Gauss]')
fig.text(0.02,0.5,'Transition Frequency from |1,1>[GHz]',va='center',rotation='vertical')
ax[0].set_ylim([6.6,7.5])
ax[1].set_ylim([0,0.5])
ax[1].set_xlim([150,185])

# Broken-axis look: hide the facing spines and keep ticks on the outer edges.
ax[0].spines["bottom"].set_visible(False)
ax[1].spines["top"].set_visible(False)
ax[0].xaxis.tick_top()
# Must be a real boolean: the string 'off' is truthy and would show the labels.
ax[0].tick_params(labeltop=False)
ax[1].xaxis.tick_bottom()

d = .01  # how big to make the diagonal lines in axes coordinates
# arguments to pass to plot, just so we don't keep repeating them
kwargs = dict(transform=ax[0].transAxes, color='k', clip_on=False)
ax[0].plot((-d, +d), (-d, +d), **kwargs)        # top-left diagonal
ax[0].plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal

kwargs.update(transform=ax[1].transAxes)  # switch to the bottom axes
ax[1].plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
ax[1].plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

# Lowest -> highest level transition frequency (energy/h) as a function of
# field in Gauss, then the field at which it equals Transtion_Freq. `B` is
# reused by the second-atom section below.
F_B = interp1d(field_gauss,(shiftsArray[:,-1]-shiftsArray[:,0])/(h))
B= brenth(lambda x: F_B(x)-Transtion_Freq,320,380)
print(B)

ax[1].add_artist(ConnectionPatch(xyA=(B,0),xyB=(B,Transtion_Freq*1e-9),coordsA='data',coordsB='data',axesA=ax[1],axesB=ax[0],color='g',lw=2,arrowstyle="<->"))
fig.subplots_adjust(hspace=0.05)
ax[0].set_title("Rb $5^2S_{1/2}$")



##########Repeat for second atom#############

# ---- Cs 6^2S_1/2: hyperfine + Zeeman level shifts versus magnetic field ----

# H_dipole and H_Quad read these globals, so they are swapped to the values
# used for the second atom before the Hamiltonian is rebuilt.
A_HF = 2.2981579425e9 * h
B_HF = 1e-9#2.724e6 * h

# Raising operators for the electronic (J) and nuclear (I) angular momenta.
J_plus = Raising_operator(TotalCs)
I_plus = Raising_operator(NuclearCs)

# Cartesian components on the product space: the J operators act on the first
# Kronecker factor and the I operators on the second.
dim_J = int(2*TotalCs+1)
dim_I = int(2*NuclearCs+1)
projections = (X_operator,Y_operator,Z_operator)
I_vec = numpy.array([numpy.kron(numpy.identity(dim_J),op(I_plus)) for op in projections])
J_vec = numpy.array([numpy.kron(op(J_plus),numpy.identity(dim_I)) for op in projections])

Hamiltonian_zero_field_quad = H_dipole(I_vec,J_vec)+H_Quad(I_vec,J_vec,NuclearCs,TotalCs)

# The Hamiltonian is Hermitian, so eigvalsh is used: it returns real
# eigenvalues already sorted in ascending order (eig returned complex values
# that were then cast into a float array).
Energies = numpy.linalg.eigvalsh(Hamiltonian_zero_field_quad)
print( '////////////////')

B_array = numpy.linspace(1e-15,0.0500,10000)

# One row of sorted level energies per field value; the column count follows
# the matrix dimension instead of being hard-coded.
shiftsArray = numpy.zeros((len(B_array),Hamiltonian_zero_field_quad.shape[0]))
for b, field in enumerate(B_array):
	Hamiltonian = Hamiltonian_zero_field_quad + H_zeeman(J_vec,field,TotalCs,orbital,spin)
	shiftsArray[b,:] = numpy.linalg.eigvalsh(Hamiltonian)

# Prepend the exact zero-field point. numpy.insert returns a new array, so
# the result has to be assigned (it was previously discarded).
B_array = numpy.insert(B_array,0,0)
shiftsArray = numpy.insert(shiftsArray,0,Energies,0)

fig,ax = pyplot.subplots(2,1,sharex=True)

# The lower manifold F = I - J holds 2F+1 = 2(I-J)+1 states, i.e. sorted
# indices 0 .. 2(I-J). Those go on the bottom panel and everything above on
# the top panel. (The former `<=` test put the lowest upper-manifold level on
# the bottom panel, where it is far outside the y-range.)
n_lower = int(2*(NuclearCs-TotalCs)+1)
field_gauss = B_array*1e4
for i in range(1,shiftsArray.shape[1]):
	panel = ax[1] if i < n_lower else ax[0]
	panel.plot(field_gauss,(shiftsArray[:,i]-shiftsArray[:,0])/(h*1e9),color='black')


ax[1].set_xlabel('Magnetic Field [Gauss]')
fig.text(0.02,0.5,'Transition Frequency from |3,3>[GHz]',va='center',rotation='vertical')
ax[0].set_ylim([9,9.9])
ax[1].set_ylim([0,0.5])
ax[1].set_xlim([150,185])

# Broken-axis look: hide the facing spines and keep ticks on the outer edges.
ax[0].spines["bottom"].set_visible(False)
ax[1].spines["top"].set_visible(False)
ax[0].xaxis.tick_top()
# Must be a real boolean: the string 'off' is truthy and would show the labels.
ax[0].tick_params(labeltop=False)
ax[1].xaxis.tick_bottom()

d = .01  # how big to make the diagonal lines in axes coordinates
# arguments to pass to plot, just so we don't keep repeating them
kwargs = dict(transform=ax[0].transAxes, color='k', clip_on=False)
ax[0].plot((-d, +d), (-d, +d), **kwargs)        # top-left diagonal
ax[0].plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal

kwargs.update(transform=ax[1].transAxes)  # switch to the bottom axes
ax[1].plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
ax[1].plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

# Lowest -> highest level transition frequency (energy/h) versus field in
# Gauss, evaluated at `B`, the field found by the root search in the first
# atom's section above.
F_B = interp1d(field_gauss,(shiftsArray[:,-1]-shiftsArray[:,0])/(h))
print( F_B(B))
ax[1].add_artist(ConnectionPatch(xyA=(B,0),xyB=(B,F_B(B)*1e-9),coordsA='data',coordsB='data',axesA=ax[1],axesB=ax[0],color='r',lw=2,arrowstyle="<->"))
fig.subplots_adjust(hspace=0.05)
ax[0].set_title("Cs $6^2S_{1/2}$")
# NOTE(review): tight_layout acts on the current figure only and recomputes
# the subplot spacing, so it may override the hspace set just above -- confirm
# this is intended.
pyplot.tight_layout()
pyplot.show()
